--- title: Categorical DQN keywords: fastai sidebar: home_sidebar summary: "An implementation of a DQN that uses distributions to represent Q from the paper A Distributional Perspective on Reinforcement Learning" description: "An implementation of a DQN that uses distributions to represent Q from the paper A Distributional Perspective on Reinforcement Learning" nb_path: "nbs/10e_agents.dqn.categorical.ipynb" ---
{% raw %}
{% endraw %} {% raw %}
{% endraw %}

The Categorical DQN can be summarized as:

Instead of each action output being a single Q value, each action now outputs a distribution of `N` atoms.

We start off with the idea of atoms and supports. A support acts as a mask over the output action distributions. This is illustrated by the equations and the corresponding functions.

We start with the equation...

$$ {\large Z_{\theta}(x,a) = z_i \quad w.p. \: p_i(x,a):= \frac{ e^{\theta_i(x,a)}} {\sum_j{e^{\theta_j(x,a)}}} } $$

... which shows that the output of our neural net needs to be squashed into a proper probability distribution. It also introduces $z_i$, an element of the support, which we will define very soon. Below is the implementation of the right-hand side of the equation, $p_i(x,a)$.

An important note is that $\frac{ e^{\theta_i(x,a)}} {\sum_j{e^{\theta_j(x,a)}}} $ is just:

{% raw %}
Softmax
torch.nn.modules.activation.Softmax
{% endraw %}

We pretend that the output of the neural net is of shape (batch_sz, n_actions, n_atoms). In this instance there is only one action, so $Z_{\theta}$ consists of a single distribution.

{% raw %}
out=Softmax(dim=2)(torch.randn(1,1,51))[0][0] # Action 0
plt.plot(out.numpy())
{% endraw %}

The next function describes how the support is calculated. The equation defines $z_i$ as: $$ \{z_i = V_{min} + i\Delta z : 0 \leq i < N \}, \: \Delta z := \frac{V_{max} - V_{min}}{N - 1} $$

Where $V_{max}$, $V_{min}$, and $N$ are constants that we define. Note that $N$ is the number of atoms. So what does a $z_i$ look like? We will define this in code below...

{% raw %}

create_support[source]

create_support(v_min=-10, v_max=10, n_atoms=51)

Creates the support and returns the z_delta that was used.

{% endraw %} {% raw %}
{% endraw %} {% raw %}
import matplotlib.pyplot as plt

support_dist,z_delta=create_support()
print('z_delta: ',z_delta)
plt.plot(support_dist.numpy())
z_delta:  0.4
{% endraw %}
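As a reference, here is a minimal sketch of how `create_support` could be implemented from the equation above (a hypothetical reimplementation; the library's actual version may differ):

```python
import torch

def create_support(v_min=-10, v_max=10, n_atoms=51):
    "Create the support {z_i} and return it with the z_delta that was used."
    z_delta = (v_max - v_min) / (n_atoms - 1)        # delta_z = (V_max - V_min) / (N - 1)
    support = torch.linspace(v_min, v_max, n_atoms)  # z_i = V_min + i * delta_z
    return support, z_delta
```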

This is the support: the set of $z_i$ values over which $Z_{\theta}$ is defined. There is one such distribution per action that the DQN is operating with. {% include note.html content='Josiah: Is this always the case? Could there be only $z_0$ and multiple actions?' %} Ok! Hopefully this wasn't too bad to go through. We basically normalized the neural net output to be easier to work with, and created/initialized a bunch of increasing arrays that we are calling discrete distributions, i.e. the output from create_support.

Now for the fun part! We have this giant ass update equation:

$$ {\large (\Phi\hat{\mathcal{T}}Z_{\theta}(x,a))_i = \sum_{j=0}^{N-1} \left[ 1 - \frac{ | \mathcal{T}z_j |_{V_{min}}^{V_{max}} - z_i }{ \Delta z } \right]_0^1 p_j(x^{\prime},\pi(x^{\prime})) } $$

Good god... and we also have

$$ \hat{\mathcal{T}}z_j := r + \gamma z_j $$

where, to quote the paper:

"for each atom $z_j$, [and] then distribute its probability $ p_j(x^{\prime},\pi(x^{\prime})) $ to the immediate neighbors of $ \hat{\mathcal{T}}z_j $"

I highly recommend reading page 6 in the paper for a fuller explanation. I originally wondered what the difference was between $\pi$ and plain $\theta$; the main difference is that $\pi$ is a greedy action selection, i.e. we run argmax to get the action.
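To make the neighbor-distribution step concrete, here is a small sketch (with made-up `reward` and `gamma` values, purely for illustration) of computing $\hat{\mathcal{T}}z_j$, clamping it to $[V_{min}, V_{max}]$, and finding the two neighboring support indices that receive the probability mass:

```python
import torch

v_min, v_max, n_atoms = -10., 10., 51
gamma, reward = 0.99, 1.0                    # illustrative values
support = torch.linspace(v_min, v_max, n_atoms)
delta_z = (v_max - v_min) / (n_atoms - 1)

target_z = torch.clamp(reward + gamma * support, v_min, v_max)  # \hat{T} z_j, clipped
b = (target_z - v_min) / delta_z             # fractional index into the support, in [0, N-1]
l, u = torch.floor(b).long(), torch.ceil(b).long()  # immediate neighbor atoms
# p_j(x', pi(x')) would then be split: p_j * (u - b) goes to atom l, p_j * (b - l) to atom u
```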

This was a lot! Luckily they have a reformulation in algorithmic form:

{% raw %}
def categorical_update(v_min,v_max,n_atoms,support,delta_z,model,reward,gamma,action,next_state):
    p=Softmax(dim=-1)(model(next_state))  # p_j(x',a) for every action
    q=(support*p).sum(dim=-1)             # Q(x',a) = sum_j z_j * p_j(x',a)
    a_star=torch.argmax(q)                # greedy next action pi(x')
    
    m=torch.zeros((n_atoms,)) # m_i = 0 where i in 0,...,N-1
    
    for j in range(n_atoms):
        # Compute the projection of $ \hat{\mathcal{T}}z_j $ onto support $ z_j $
        target_z=torch.clamp(reward+gamma*support[j],v_min,v_max)
        b_j=(target_z-v_min)/delta_z # b_j in [0,N-1]
        l=torch.floor(b_j).long()
        u=torch.ceil(b_j).long()
        # Distribute probability of $ \hat{\mathcal{T}}z_j $ to neighbors l and u
        m[l]=m[l]+p[a_star,j]*(u-b_j)
        m[u]=m[u]+p[a_star,j]*(b_j-l)
    return m # Target for some cross entropy loss
{% endraw %}

The above was a (fairly) literal conversion of Algorithm 1 in the paper to Python, and there are still some problems:

  • The current setup doesn't handle batches
  • Some of the variable names are vague
  • It does not handle terminal states

Let's rename these! We will instead have:
$$ m\_i \rightarrow projection\\ a\_star \rightarrow next\_action\\ b\_j \rightarrow support\_value\\ l \rightarrow support\_left\\ u \rightarrow support\_right\\ $$
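Applying those renames, a single-sample sketch of the projection step could look like the following. This is illustrative only (still not batched, and terminal states are still ignored); `next_p` is assumed to be the softmaxed atom probabilities for the greedy next action:

```python
import torch

def projected_update(v_min, v_max, n_atoms, support, delta_z, next_p, reward, gamma):
    "next_p: atom probabilities p_j(x_{t+1}, a*) for the greedy next action, shape (n_atoms,)"
    projection = torch.zeros(n_atoms)
    for atom in range(n_atoms):
        target_z = torch.clamp(reward + gamma * support[atom], v_min, v_max)
        support_value = (target_z - v_min) / delta_z       # b_j in [0, N-1]
        support_left = int(torch.floor(support_value))
        support_right = int(torch.ceil(support_value))
        if support_left == support_right:                  # b_j landed exactly on an atom
            projection[support_left] += next_p[atom]
        else:
            projection[support_left] += next_p[atom] * (support_right - support_value)
            projection[support_right] += next_p[atom] * (support_value - support_left)
    return projection
```

Because each atom's probability is split between its two immediate neighbors, the resulting projection still sums to 1.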

So let's revise the problem and pretend that we have a 2 action model with a batch size of 8, where the last element is a terminal state, and where left actions are -1 while right actions are 1.

{% raw %}
from torch.distributions.normal import Normal
{% endraw %}

So for a single action we would have a distribution like this...

{% raw %}
plt.plot(Normal(0,1).sample((51,)).numpy())
{% endraw %}

So since our model has 2 actions that it can pick, we create some distributions for them...

{% raw %}
dist_left=torch.vstack([Normal(0.5,1).sample((1,51)),Normal(0.5,0.1).sample((1,51))]).unsqueeze(0)
dist_right=torch.vstack([Normal(0.5,0.1).sample((1,51)),Normal(0.5,1).sample((1,51))]).unsqueeze(0)
(dist_left.shape,dist_right.shape)
(torch.Size([1, 2, 51]), torch.Size([1, 2, 51]))
{% endraw %}

...where the $[1, 2, 51]$ is $[batch, action, n\_atoms]$

{% raw %}
model_out=torch.vstack([copy([dist_left,dist_right][i%2==0]) for i in range(1,9)]).to(device=default_device())
(model_out.shape)
torch.Size([8, 2, 51])
{% endraw %} {% raw %}
summed_model_out=model_out.sum(dim=2);summed_model_out=Softmax(dim=1)(summed_model_out).to(device=default_device())
(summed_model_out.shape,summed_model_out)
(torch.Size([8, 2]),
 tensor([[4.0628e-01, 5.9372e-01],
         [9.4022e-07, 1.0000e+00],
         [4.0628e-01, 5.9372e-01],
         [9.4022e-07, 1.0000e+00],
         [4.0628e-01, 5.9372e-01],
         [9.4022e-07, 1.0000e+00],
         [4.0628e-01, 5.9372e-01],
         [9.4022e-07, 1.0000e+00]], device='cuda:0'))
{% endraw %}

So when we sum/normalize the distributions per batch, per action, we get an output that looks like your typical DQN output...

We can also treat this like a regular DQN and do an argmax to get actions like usual...

{% raw %}
actions=torch.argmax(summed_model_out,dim=1).reshape(-1,1).to(device=default_device());actions
tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]], device='cuda:0')
{% endraw %} {% raw %}
rewards=actions;rewards
tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]], device='cuda:0')
{% endraw %} {% raw %}
dones=Tensor().new_zeros((8,1)).bool().to(device=default_device());dones[-1][0]=1;dones
tensor([[False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [ True]], device='cuda:0')
{% endraw %}

So let's decompose the categorical_update above into something easier to read. First we will note the authors' original algorithm:

{% include image.html width="500" height="500" max-width="500" file="/fastrl/docs/images/10e_agents.dqn.categorical_algorithm1.png" %}

We can break this into 3 different functions:

- getting the Q<br>
- calculating the update<br>
- calculating the loss

We will start with the $Q(x_{t+1},a):=\sum_iz_ip_i(x_{t+1},a)$
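In code, this sum is just a dot product between the support and the softmaxed atom probabilities. A sketch, assuming a model output of shape `(batch, n_actions, n_atoms)`:

```python
import torch

support = torch.linspace(-10, 10, 51)    # {z_i}
logits = torch.randn(8, 2, 51)           # stand-in for raw model output
p = torch.nn.Softmax(dim=-1)(logits)     # p_i(x_{t+1}, a), one distribution per action
q = (support * p).sum(dim=-1)            # Q(x_{t+1}, a) = sum_i z_i * p_i(x_{t+1}, a)
# q has shape (8, 2): one scalar Q value per action, just like a regular DQN
```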

{% raw %}

class CategoricalDQN[source]

CategoricalDQN(state_sz:int, action_sz:int, n_atoms:int=51, hidden=512, v_min=-10, v_max=10) :: Module

Same as nn.Module, but no need for subclasses to call super().__init__

{% endraw %} {% raw %}
{% endraw %}

The CategoricalDQN.q function gets us 90% of the way to the equation above. However, you will notice that the equation is for a specific action. We will handle this in the actual update function.

{% raw %}
dqn=CategoricalDQN(4,2).to(device=default_device())
dqn(torch.randn(8,4).to(device=default_device())).shape
torch.Size([8, 2, 51])
{% endraw %} {% raw %}
dqn.q(torch.randn(8,4).to(device=default_device()))
tensor([[-0.0131,  0.4737],
        [-0.0574,  0.2307],
        [ 0.0965,  0.2352],
        [-0.0143,  0.3457],
        [-0.0161,  0.0787],
        [-0.0951,  0.6884],
        [ 0.2235, -0.0694],
        [-0.0752, -0.0934]], device='cuda:0', grad_fn=<SumBackward1>)
{% endraw %} {% raw %}
dqn.policy(torch.randn(8,4).to(device=default_device()))
tensor([[ 0.0078,  0.0033],
        [ 0.0022,  0.0012],
        [-0.0007, -0.0001],
        [-0.0015, -0.0012],
        [-0.0026,  0.0017],
        [-0.0008, -0.0009],
        [-0.0015, -0.0031],
        [ 0.0043, -0.0018]], device='cuda:0', grad_fn=<MeanBackward1>)
{% endraw %} {% raw %}

distribute[source]

distribute(projection, left, right, support_value, p_a, atom, done)

Does: m_l <- m_l + pj(x{t+1},a*)(u - b_j) operation for non-final states.

{% endraw %} {% raw %}

final_distribute[source]

final_distribute(projection, left, right, support_value, p_a, atom, done)

Does: m_l <- m_l + pj(x{t+1},a*)(u - b_j) operation for final states.

{% endraw %} {% raw %}
{% endraw %} {% raw %}

categorical_update[source]

categorical_update(support, delta_z, q, p, actions, rewards, dones, v_min=-10, v_max=10, n_atoms=51, gamma=0.99, passes=None)

{% endraw %} {% raw %}
{% endraw %} {% raw %}

show_q_distribution[source]

show_q_distribution(cat_dist, title='Update Distributions')

cat_dist being shape: (bs,n_atoms)

{% endraw %} {% raw %}
{% endraw %} {% raw %}
output=categorical_update(dqn.supports,dqn.z_delta,summed_model_out,
                          Softmax(dim=2)(model_out),actions,rewards,dones,passes=None)
show_q_distribution(output)
{% endraw %} {% raw %}
q=dqn.q(torch.randn(8,4).to(device=default_device()))
p=dqn.p(torch.randn(8,4).to(device=default_device()))

output=categorical_update(dqn.supports,dqn.z_delta,q,p,actions,rewards,dones)
show_q_distribution(output,title='Real Model Update Distributions')
{% endraw %} {% raw %}

PartialCrossEntropy[source]

PartialCrossEntropy(p, q)

{% endraw %} {% raw %}
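PartialCrossEntropy pairs the model's predicted log-distribution with the projected target distribution. A hypothetical sketch of what such a loss could look like, consistent with the manual `(-pred*target).sum(dim=1).mean()` computation further down in this notebook:

```python
import torch

def partial_cross_entropy(pred_log_p, target_p):
    "Cross entropy between a projected target distribution and predicted log-probabilities."
    # pred_log_p: (bs, n_atoms) log-probabilities for the taken actions
    # target_p:   (bs, n_atoms) projection produced by categorical_update
    return (-pred_log_p * target_p).sum(dim=1).mean()
```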
{% endraw %} {% raw %}

class CategoricalDQNTrainer[source]

CategoricalDQNTrainer(n_batch=0, target_sync=300, discount=0.99, n_steps=1) :: Callback

Basic class handling tweaks of the training loop by changing a Learner in various events

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class CategoricalArgMaxFeed[source]

CategoricalArgMaxFeed() :: AgentCallback

Basic class handling tweaks of a callback loop by changing a obj in various events

{% endraw %} {% raw %}
{% endraw %} {% raw %}
dqn=CategoricalDQN(4,2)

agent=Agent(dqn,cbs=[ArgMaxFeed,DiscreteEpsilonRandomSelect])
source=Source(cbs=[GymLoop('CartPole-v1',agent,steps_count=3,seed=0,
                           steps_delta=1),FirstLast])
dls=SourceDataBlock().dataloaders([source],n=1000,bs=1,num_workers=0)

learn=Learner(dls,agent,loss_func=PartialCrossEntropy,
              cbs=[ExperienceReplay(bs=32,max_sz=100000,warmup_sz=32),CategoricalDQNTrainer(target_sync=300)],
              metrics=[Reward,Epsilon])
{% endraw %} {% raw %}

{% endraw %} {% raw %}
full=True
learn.fit(47 if full else 3,lr=0.0001,wd=0)
epoch train_loss train_reward train_epsilon valid_loss valid_reward valid_epsilon time
0 3.746402 24.020000 0.700000 00:31
1 3.648625 18.780000 0.400000 00:35
2 3.552434 18.270000 0.100000 00:35
3 3.441176 16.150000 0.020000 00:35
4 3.343086 15.280000 0.020000 00:35
5 3.248488 15.050000 0.020000 00:35
6 3.157016 15.840000 0.020000 00:35
7 3.109603 14.760000 0.020000 00:35
8 3.047756 14.790000 0.020000 00:35
9 2.943340 15.610000 0.020000 00:35
10 2.800304 14.780000 0.020000 00:35
11 2.721919 17.700000 0.020000 00:36
12 2.623879 18.980000 0.020000 00:35
13 2.568850 22.410000 0.020000 00:35
14 2.543608 18.490000 0.020000 00:35
15 2.533373 21.960000 0.020000 00:35
16 2.511398 21.420000 0.020000 00:35
17 2.509220 23.690000 0.020000 00:35
18 2.487636 23.810000 0.020000 00:35
19 2.459220 23.400000 0.020000 00:35
20 2.469459 26.730000 0.020000 00:35
21 2.488032 26.320000 0.020000 00:36
22 2.481231 26.650000 0.020000 00:35
23 2.469014 26.620000 0.020000 00:35
24 2.475075 24.760000 0.020000 00:35
25 2.449165 25.340000 0.020000 00:35
26 2.449753 31.470000 0.020000 00:35
27 2.411190 32.450000 0.020000 00:40
28 2.410106 28.650000 0.020000 00:44
29 2.385669 31.490000 0.020000 00:43
30 2.372591 29.360000 0.020000 00:45
31 2.369168 31.480000 0.020000 00:44
32 2.337914 39.750000 0.020000 00:44
33 2.311644 37.850000 0.020000 00:44
34 2.305213 34.390000 0.020000 00:44
35 2.262858 36.700000 0.020000 00:44
36 2.254606 39.650000 0.020000 00:45
37 2.273282 36.010000 0.020000 00:44
38 2.248906 45.780000 0.020000 00:45
39 2.221158 34.290000 0.020000 00:44
40 2.206461 34.270000 0.020000 00:44
41 2.167237 43.120000 0.020000 00:45
42 2.158902 39.920000 0.020000 00:44
43 2.182625 42.020000 0.020000 00:43
44 2.193391 43.800000 0.020000 00:45
45 2.144330 40.500000 0.020000 00:44
46 2.073807 45.150000 0.020000 00:44
{% endraw %} {% raw %}
from IPython.display import HTML
import plotly.express as px
{% endraw %} {% raw %}

show_q[source]

show_q(cat_dist, title='Update Distributions')

cat_dist being shape: (bs,n_atoms)

{% endraw %} {% raw %}
{% endraw %} {% raw %}
learn.cbs[-1].local_pred.shape
torch.Size([32, 51])
{% endraw %} {% raw %}
learn.cbs[-1].local_v.shape
torch.Size([32, 2, 51])
{% endraw %} {% raw %}
show_q(learn.cbs[-1].local_xb[0])
{% endraw %} {% raw %}
show_q(learn.cbs[-1].local_pred)
{% endraw %} {% raw %}
show_q(learn.cbs[-1].local_v[:,1,:])
{% endraw %} {% raw %}
show_q(learn.cbs[-1].local_v[:,0,:])
{% endraw %} {% raw %}
(-learn.cbs[-1].local_pred*learn.cbs[-1].local_xb[0]).sum(dim=1).mean()
TensorBatch(1.9943, device='cuda:0', grad_fn=<AliasBackward>)
{% endraw %} {% raw %}
show_q(-learn.cbs[-1].local_pred*learn.cbs[-1].local_xb[0])
{% endraw %} {% raw %}
from IPython.display import HTML
import plotly.express as px

agent=Agent(dqn,cbs=[CategoricalArgMaxFeed,DiscreteEpsilonRandomSelect(min_epsilon=0.0001,max_epsilon=0.0002,epsilon=0.0002)])
source=Src('CartPole-v1',agent,seed=0,steps_count=1,n_envs=1,steps_delta=1,mode='rgb_array',cbs=[GymSrc,FirstLast])

exp=[o for o,_ in zip(source,range(50))]

fig = px.imshow(torch.vstack([o['image'] for o in exp]).numpy(),animation_frame=0)
HTML(fig.to_html())
{% endraw %} {% raw %}

show_q_and_max_distribution[source]

show_q_and_max_distribution(cat_dist, title='Update Distributions')

cat_dist being shape: (bs,n_atoms)

{% endraw %} {% raw %}
{% endraw %} {% raw %}
show_q_and_max_distribution(dqn.policy(torch.vstack([o['state'] for o in exp]).to(device=default_device())))
{% endraw %}

If you want to run this using multiple processes, the multiprocessing code looks like below. However, you will not be able to run this in a notebook; instead, add it to a .py file and run it from there.

{% include warning.html content='There is a bug in data block that prevents this. Should be a simple fix.' %}